package database

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"sync"
	"time"

	"gitee.com/zhucheer/orange/cfg"
	"gitee.com/zhucheer/orange/logger"
	pgo "github.com/go-pg/pg/v10"
	"github.com/zhuCheer/pool"
)

// postgreConn is the package-level PostgreDB instance; NewPostgre creates it
// on first call and GetPostgre reads it.
var postgreConn *PostgreDB

// PostgreDB holds one connection pool per registered postgre entry.
type PostgreDB struct {
	// connPool maps a config entry name (the key under "database.postgre")
	// to the pool created for it by Register.
	connPool map[string]pool.Pool
	// count is the number of entries RegisterAll found in the config.
	count    int
	// lock is used by insertPool to serialise writes to connPool.
	lock     sync.Mutex
}

// NewPostgre returns the package-level PostgreDB instance, creating it with an
// empty pool map on the first call; every later call returns that same value.
// NOTE(review): the nil check is unsynchronised, so concurrent first calls
// could each build an instance — call it once during startup.
func NewPostgre() DataBase {
	if postgreConn == nil {
		postgreConn = &PostgreDB{connPool: map[string]pool.Pool{}}
	}
	return postgreConn
}

// RegisterAll registers a pool for every entry found under "database.postgre"
// in the configuration, recording the number of entries in pg.count first.
func (pg *PostgreDB) RegisterAll() {
	entries := cfg.Config.GetMap("database.postgre")
	pg.count = len(entries)

	for entryName := range entries {
		pg.Register(entryName)
	}
}

// Register creates a connection pool for the postgre config entry `name`
// (keys under "database.postgre.<name>") and stores it via insertPool.
//
// When "url" is set it is used verbatim as the DSN. Otherwise a DSN is built
// from addr, username, password and dbname, plus the timeout and sslMode
// keys (read with fallback values 5 and "disable"). Pool sizing comes from
// initCap, maxCap and idleTimeout (seconds) via getDBIntConfig.
//
// Register panics when the DSN cannot be parsed. When the pool itself cannot
// be created, it logs the error and registers nothing.
func (pg *PostgreDB) Register(name string) {
	prefix := "database.postgre." + name
	connURL := cfg.Config.GetString(prefix + ".url")
	addr := cfg.Config.GetString(prefix + ".addr")
	username := cfg.Config.GetString(prefix + ".username")
	password := cfg.Config.GetString(prefix + ".password")
	dbname := cfg.Config.GetString(prefix + ".dbname")
	connTimeout := cfg.GetInt(prefix+".timeout", 5)
	sslMode := cfg.GetString(prefix+".sslMode", "disable")

	initCap := getDBIntConfig("postgre", name, "initCap")
	maxCap := getDBIntConfig("postgre", name, "maxCap")
	idleTimeout := getDBIntConfig("postgre", name, "idleTimeout")

	dsn := connURL
	if dsn == "" {
		dsn = buildPostgreDSN(username, password, addr, dbname, connTimeout, sslMode)
	}
	opt, err := pgo.ParseURL(dsn)
	if err != nil {
		panic(fmt.Errorf("parse postgre dsn [%s]: %w", name, err))
	}

	// connPostgre is the pool factory: each pooled item is a *pgo.DB handle.
	connPostgre := func() (interface{}, error) {
		conn := pgo.Connect(opt)
		// NOTE(review): go-pg documents Connect as lazy and never returning
		// nil, so this guard is defensive; reachability is checked by
		// pingPostgre when the pool pings a connection.
		if conn == nil {
			return nil, errors.New("pgsql connect error")
		}
		// Explicit nil: the outer err is reassigned by NewChannelPool below
		// and must not leak into the factory's result.
		return conn, nil
	}

	// closePostgre releases a pooled handle, propagating any close error.
	closePostgre := func(v interface{}) error {
		return v.(*pgo.DB).Close()
	}

	// pingPostgre checks that a pooled handle can still reach the server.
	pingPostgre := func(v interface{}) error {
		return v.(*pgo.DB).Ping(context.Background())
	}

	p, err := pool.NewChannelPool(&pool.Config{
		InitialCap: initCap,
		MaxCap:     maxCap,
		Factory:    connPostgre,
		Close:      closePostgre,
		Ping:       pingPostgre,
		// Connections idle longer than this are closed, which avoids
		// handing out connections that have silently gone stale (EOF).
		IdleTimeout: time.Duration(idleTimeout) * time.Second,
	})
	if err != nil {
		logger.Error("register postgre conn [%s] error:%v", name, err)
		return
	}
	pg.insertPool(name, p)
}

// buildPostgreDSN assembles a postgresql:// URL from its parts. Credentials,
// the database name and the query values are percent-escaped, so characters
// such as '@', ':', '/', '?' or '#' in a password cannot corrupt the URL.
func buildPostgreDSN(username, password, addr, dbname string, connTimeout int, sslMode string) string {
	query := url.Values{}
	query.Set("connect_timeout", fmt.Sprint(connTimeout))
	query.Set("sslmode", sslMode)

	u := url.URL{
		Scheme:   "postgresql",
		User:     url.UserPassword(username, password),
		Host:     addr,
		Path:     "/" + dbname,
		RawQuery: query.Encode(),
	}
	return u.String()
}

// insertPool stores p in pg.connPool under name, creating the map if it is
// still nil. An existing pool stored under the same name is overwritten.
func (pg *PostgreDB) insertPool(name string, p pool.Pool) {
	// Take the lock before the nil check so concurrent callers cannot race
	// on the lazy map initialisation.
	pg.lock.Lock()
	defer pg.lock.Unlock()

	if pg.connPool == nil {
		pg.connPool = make(map[string]pool.Pool)
	}
	pg.connPool[name] = p
}

// getDB borrows a connection from the pool registered under name.
//
// It returns the raw pooled value and a put func that hands that value back
// to the pool it was taken from. put is never nil: it is a no-op when err is
// non-nil, so callers can defer it unconditionally. Call it at most once per
// successful borrow.
//
// An error is returned when no pool is registered under name, or when the
// pool's Get fails (the pool error is wrapped).
func (pg *PostgreDB) getDB(name string) (conn interface{}, put func(), err error) {
	put = func() {}

	// Read under the same mutex insertPool writes under.
	pg.lock.Lock()
	p, ok := pg.connPool[name]
	pg.lock.Unlock()
	if !ok {
		return nil, put, errors.New("no postgreSQL connect")
	}

	conn, err = p.Get()
	if err != nil {
		return nil, put, fmt.Errorf("postgreSQL get connect err:%w", err)
	}

	// Capture p rather than re-reading the map, so the connection returns to
	// its own pool even if name is re-registered in the meantime. put has no
	// error return, so Put's error is intentionally discarded.
	put = func() {
		_ = p.Put(conn)
	}
	return conn, put, nil
}

// putDB returns db to the pool registered under name. It returns an error
// when no pool exists under name; otherwise it returns whatever the pool's
// Put returns.
func (pg *PostgreDB) putDB(name string, db interface{}) error {
	// Single lookup, under the same mutex insertPool writes under.
	pg.lock.Lock()
	p, ok := pg.connPool[name]
	pg.lock.Unlock()
	if !ok {
		return errors.New("no postgreSQL connect")
	}
	return p.Put(db)
}

// GetPostgre borrows a *pgo.DB from the pool registered under name.
//
// The returned put func gives the handle back to its pool. It is non-nil even
// on error (then a no-op), so `defer put()` is safe right after the call.
// An error is returned when the package-level instance has not been created
// yet (see NewPostgre), or when no connection can be obtained for name.
func GetPostgre(name string) (db *pgo.DB, put func(), err error) {
	if postgreConn == nil {
		return nil, func() {}, errors.New("db connect is nil")
	}

	raw, release, getErr := postgreConn.getDB(name)
	if getErr != nil {
		return nil, release, getErr
	}
	// Every value in these pools is produced by Register's factory, which
	// yields *pgo.DB.
	return raw.(*pgo.DB), release, nil
}
